package com.study.crypto.basic.digest;

import com.study.crypto.basic.utils.KeyUtils;
import org.bouncycastle.asn1.x509.Certificate;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECFieldElement;
import org.bouncycastle.math.ec.ECPoint;

import java.nio.charset.StandardCharsets;
import java.util.Objects;

/**
 * 带公钥计算 sm3 摘要
 *
 * @author Songjin
 * @since 2022-03-28 15:39
 */
public class SM3PublicKeyDigest extends SM3Digest {
    
    private final ECPoint publicPoint;
    
    /**
     * 用户证书
     * @param certificate 证书
     */
    public SM3PublicKeyDigest(Certificate certificate) {
        byte[] publicKeyBytes = certificate.getSubjectPublicKeyInfo().getPublicKeyData().getBytes();
        publicPoint = KeyUtils.sm2p256v1.getCurve().decodePoint(publicKeyBytes);
        byte[] z = this.getZ("1234567812345678".getBytes());
        this.update(z, 0, z.length);
    }
    
    /**
     * 用户公钥
     * @param publicKey 公钥字节数据，65字节
     */
    public SM3PublicKeyDigest(byte[] publicKey) {
        publicPoint = KeyUtils.sm2p256v1.getCurve().decodePoint(publicKey);
        byte[] z = this.getZ("1234567812345678".getBytes());
        this.update(z, 0, z.length);
    }
    
    private byte[] getZ(byte[] userID) {
        this.reset();
        
        SM3Digest digest = new SM3Digest();
        addUserID(digest, userID);
    
        ECCurve curve = KeyUtils.sm2p256v1.getCurve();
        ECPoint pointG = KeyUtils.sm2p256v1.getG();
        addFieldElement(digest, curve.getA());
        addFieldElement(digest, curve.getB());
        addFieldElement(digest, pointG.getAffineXCoord());
        addFieldElement(digest, pointG.getAffineYCoord());
        addFieldElement(digest, publicPoint.getAffineXCoord());
        addFieldElement(digest, publicPoint.getAffineYCoord());
        
        byte[] result = new byte[this.getDigestSize()];
        digest.doFinal(result, 0);
        return result;
    }
    
    private void addUserID(Digest digest, byte[] userID) {
        int len = userID.length * 8;
        digest.update((byte) (len >> 8 & 0xFF));
        digest.update((byte) (len & 0xFF));
        digest.update(userID, 0, userID.length);
    }
    
    private void addFieldElement(Digest digest, ECFieldElement v) {
        byte[] p = v.getEncoded();
        digest.update(p, 0, p.length);
    }
}
